library(tidyverse)
library(rstan)
# rstan setup: allow one core per chain when sampling in parallel, and cache
# compiled models on disk (auto_write) so an unchanged .stan file is reused
# instead of recompiled.
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# Fit the lasso-penalised Poisson model by optimisation (rstan::optimizing),
# initialised at the posterior medians of an earlier variational (VB) fit.
#
# Usage: Rscript <script> vb_input <path/to/file_with_vb.rds>
#
# The input .rds must be a list with elements `data` (the Stan data list) and
# `vb_fit` (a stanfit). The VB fit must contain a scalar parameter `lambda`,
# which is passed to the MAP model as fixed data. The optimizing() result is
# saved to the input path with "vb" replaced by "MAP".

# ---- Helpers ----------------------------------------------------------------

# Return the path that follows the "vb_input" keyword in `args`; stop with an
# explanatory error when the keyword or the path after it is missing.
get_vb_input_path <- function(args) {
  idx <- match("vb_input", args)
  if (is.na(idx)) {
    stop("Provide a vb estimation as input with the keyword vb_input",
      call. = FALSE
    )
  }
  if (idx == length(args)) {
    stop("No file name given after the keyword vb_input", call. = FALSE)
  }
  args[[idx + 1L]]
}

# Posterior median of every model parameter of `fit` (except lp__), returned
# as a named list of unnamed numeric vectors (suitable as an `init` list).
posterior_medians <- function(fit) {
  pars <- setdiff(fit@model_pars, "lp__")
  medians <- lapply(pars, function(parname) {
    unname(summary(fit, pars = parname)$summary[, "50%"])
  })
  setNames(medians, pars)
}

# Write `text` to `path` only when the file is missing or its content differs.
# Rewriting an unchanged file bumps its modification time, which can make
# rstan's auto_write cache treat the compiled model as stale and recompile.
write_if_changed <- function(text, path) {
  unchanged <- file.exists(path) &&
    identical(readChar(path, file.size(path), useBytes = TRUE), text)
  if (!unchanged) {
    dir.create(dirname(path), showWarnings = FALSE, recursive = TRUE)
    cat(text, file = path)
  }
  invisible(path)
}

# ---- Inputs -----------------------------------------------------------------

args <- commandArgs(trailingOnly = TRUE)
vb_filename <- get_vb_input_path(args)

# Read the VB result once; `[[` avoids `$` partial name matching on lists.
vb_result <- readRDS(vb_filename)
stan_data <- vb_result[["data"]]
vb_fit <- vb_result[["vb_fit"]]
if (!inherits(vb_fit, "stanfit")) {
  stop("Element 'vb_fit' of ", vb_filename, " is not a stanfit object",
    call. = FALSE
  )
}

# Output path; refuse to continue if it would overwrite the VB input.
filename <- gsub("vb", "MAP", vb_filename, fixed = TRUE)
if (identical(filename, vb_filename)) {
  stop("Input file name must contain 'vb' so the MAP result does not ",
    "overwrite it: ", vb_filename,
    call. = FALSE
  )
}

median_vb_estimates <- posterior_medians(vb_fit)

# lambda is a parameter of the VB model but fixed data in the MAP model.
# Validate it before the (slow) compilation step.
lambda <- median_vb_estimates[["lambda"]]
if (length(lambda) != 1L) {
  stop("The VB fit must contain a scalar parameter 'lambda'", call. = FALSE)
}
stan_data$lambda <- lambda

# ---- Model ------------------------------------------------------------------

Poisson_model_lasso_MAP <- "
data {
  int<lower=0> N;                 // number of observations
  int<lower=0> N_municipalities;  // number of municipalities
  int<lower=0> NbDeath[N];        // observed death counts
  int<lower=0> Population[N];     // population of each observation
  int<lower=1, upper=N_municipalities> Municipality_code[N]; // 1-based index
  int<lower=0> covid[N];          // multiplier of the excess rate hcp
  real<lower=0> lambda;           // rate of the exponential (lasso) prior on hcp
}
parameters {
  real<upper=0> log_hc[N_municipalities];
  real<lower=0, upper=1> hcp[N_municipalities];
  real mu;
  real<lower=0> sigma;
}
transformed parameters {
  real hc[N_municipalities];
  hc = exp(log_hc);
}
model {
  real lambda_eff[N];

  mu ~ normal(-5, 5);
  sigma ~ normal(1, 20);

  log_hc ~ normal(mu, sigma);
  hcp ~ exponential(lambda);

  for (i in 1:N) {
    lambda_eff[i] = Population[i]*(hc[Municipality_code[i]] + covid[i]*hcp[Municipality_code[i]]);
  }

  NbDeath ~ poisson(lambda_eff);
}
"
stan_file <- "Scripts/Poisson_model_lasso_MAP.stan"
write_if_changed(Poisson_model_lasso_MAP, stan_file)

# system.time() output is not auto-printed inside a script, so print it.
compile_time <- system.time(m <- stan_model(file = stan_file))
print(compile_time)

# ---- Optimisation -----------------------------------------------------------

# Extra entries in the init list (e.g. lambda, hc) are not parameters of this
# model. NOTE(review): relies on rstan ignoring unused init names — confirm.
MLE_fit <- optimizing(
  object = m, data = stan_data, verbose = TRUE,
  init = median_vb_estimates, iter = 50000
)
if (!is.null(MLE_fit$return_code) && MLE_fit$return_code != 0) {
  warning("optimizing() returned code ", MLE_fit$return_code,
    "; the MAP estimate may not have converged",
    call. = FALSE
  )
}

print(paste("saving to", filename))

saveRDS(MLE_fit, file = filename)
